{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jMGcXXPabEN4"
      },
      "source": [
        "# Uni-Fold Colab\n",
        "\n",
        "This Colab notebook provides an online runnable version of [Uni-Fold](https://github.com/dptech-corp/Uni-Fold/) for users to predict the structure of a protein, single chain or multimer, with custom settings.\n",
        "\n",
        "Thanks to  [MMSeqs2](https://github.com/soedinglab/MMseqs2.git) and the server provided by [ColabFold](https://github.com/sokrypton/ColabFold), the homogeneous searching in this notebook is very fast and is comparable with the original AlphaFold(-Multimer). If you want more consistent results with the original AlphaFold(-Multimer), you can use the [full open source Uni-Fold](https://github.com/dptech-corp/Uni-Fold/), or the convenient web server at [Hermite™](https://hermite.dp.tech/).\n",
        "\n",
        "Please note that this Colab notebook is not a finished product and is provided as an early-access prototype. It is provided for theoretical modeling only and caution should be exercised in its use. \n",
        "\n",
        "**Licenses**\n",
        "\n",
        "This Colab uses the [Uni-Fold model parameters](https://github.com/dptech-corp/Uni-Fold/#model-parameters-license) and its outputs are under the terms of the Creative Commons Attribution 4.0 International (CC BY 4.0) license. You can find details at: https://creativecommons.org/licenses/by/4.0/legalcode. The Colab itself is provided under the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0).\n",
        "\n",
        "\n",
        "**Citations**\n",
        "\n",
        "Please cite the following papers if you use this notebook:\n",
        "\n",
        "*   Jumper et al. \"[Highly accurate protein structure prediction with AlphaFold.](https://doi.org/10.1038/s41586-021-03819-2)\" Nature (2021)\n",
        "*   Evans et al. \"[Protein complex prediction with AlphaFold-Multimer.](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1)\" biorxiv (2021)\n",
        "*   Mirdita M, Schütze K, Moriwaki Y, Heo L, Ovchinnikov S and Steinegger M. \"[ColabFold: Making protein folding accessible to all.](https://www.nature.com/articles/s41592-022-01488-1)\" Nature Methods (2022) \n",
        "*   Ziyao Li, Xuyang Liu, Weijie Chen, Fan Shen, Hangrui Bi, Guolin Ke, Linfeng Zhang. \"[Uni-Fold: An Open-Source Platform for Developing Protein Folding Models beyond AlphaFold.](https://www.biorxiv.org/content/10.1101/2022.08.04.502811v1)\" biorxiv (2022)\n",
        "\n",
        "\n",
        "**Acknowledgements**\n",
        "\n",
        "We thank [@sokrypton](https://twitter.com/sokrypton) for many helpful suggestions to this notebook.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "y0Evc150bEN7"
      },
      "outputs": [],
      "source": [
        "#@title Install third-party software\n",
        "#@markdown Please execute this cell by pressing the _Play_ button \n",
        "#@markdown on the left to download and import third-party software \n",
        "#@markdown in this Colab notebook. (See the [acknowledgements](https://github.com/dptech-corp/Uni-Fold/#acknowledgements) in our readme.)\n",
        "\n",
        "#@markdown **Note**: This installs the software on the Colab \n",
        "#@markdown notebook in the cloud and not on your computer.\n",
        "\n",
        "!apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y -qq \\\n",
        "    hmmer \\\n",
        "    kalign\n",
        "\n",
        "# Install HHsuite.\n",
        "!wget -q https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-AVX2-Linux.tar.gz; tar xfz hhsuite-3.3.0-AVX2-Linux.tar.gz; ln -s $(pwd)/bin/* /usr/bin \n",
        "\n",
        "!pip3 -q install py3dmol gdown"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "rETqvokYbEN9"
      },
      "outputs": [],
      "source": [
        "#@title Download Uni-Fold\n",
        "\n",
        "#@markdown Please execute this cell by pressing the *Play* button on \n",
        "#@markdown the left.\n",
        "\n",
        "GIT_REPO = 'https://github.com/dptech-corp/Uni-Fold'\n",
        "UNICORE_URL = 'https://github.com/dptech-corp/Uni-Core/releases/download/0.0.1/unicore-0.0.1+cu113torch1.12.1-cp37-cp37m-linux_x86_64.whl'\n",
        "PARAM_URL = 'https://drive.google.com/uc?id=1A9iXMYCwP0f_U0FgISJ_6BX7FXZtglvV'\n",
        "\n",
        "!rm *.whl\n",
        "!wget  {UNICORE_URL} \n",
        "!pip3 -q install \"unicore-0.0.1+cu113torch1.12.1-cp37-cp37m-linux_x86_64.whl\"\n",
        "!rm -rf ./Uni-Fold\n",
        "!git clone -b main {GIT_REPO}\n",
        "!pip3 -q install ./Uni-Fold\n",
        "!gdown {PARAM_URL}\n",
        "!tar -xzf \"unifold_params_2022-08-01.tar.gz\"\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "j-xTD0QubEN-"
      },
      "outputs": [],
      "source": [
        "#@title Input protein sequence(s), then hit `Runtime` -> `Run all`\n",
        "import os\n",
        "import re\n",
        "import hashlib\n",
        "import random\n",
        "import numpy as np\n",
        "from pathlib import Path\n",
        "from typing import Dict, List, Sequence, Tuple, Union, Any, Optional\n",
        "\n",
        "from unifold.data import residue_constants, protein\n",
        "from unifold.msa.utils import divide_multi_chains\n",
        "\n",
        "MIN_SINGLE_SEQUENCE_LENGTH = 16\n",
        "MAX_SINGLE_SEQUENCE_LENGTH = 1000\n",
        "MAX_MULTIMER_LENGTH = 1000\n",
        "\n",
        "output_dir_base = \"./prediction\"\n",
        "os.makedirs(output_dir_base, exist_ok=True)\n",
        "\n",
        "\n",
        "def clean_and_validate_sequence(\n",
        "    input_sequence: str, min_length: int, max_length: int) -> str:\n",
        "  \"\"\"Checks that the input sequence is ok and returns a clean version of it.\"\"\"\n",
        "  # Remove all whitespaces, tabs and end lines; upper-case.\n",
        "  clean_sequence = input_sequence.translate(\n",
        "      str.maketrans('', '', ' \\n\\t')).upper()\n",
        "  aatypes = set(residue_constants.restypes)  # 20 standard aatypes.\n",
        "  if not set(clean_sequence).issubset(aatypes):\n",
        "    raise ValueError(\n",
        "        f'Input sequence contains non-amino acid letters: '\n",
        "        f'{set(clean_sequence) - aatypes}. AlphaFold only supports 20 standard '\n",
        "        'amino acids as inputs.')\n",
        "  if len(clean_sequence) < min_length:\n",
        "    raise ValueError(\n",
        "        f'Input sequence is too short: {len(clean_sequence)} amino acids, '\n",
        "        f'while the minimum is {min_length}')\n",
        "  if len(clean_sequence) > max_length:\n",
        "    raise ValueError(\n",
        "        f'Input sequence is too long: {len(clean_sequence)} amino acids, while '\n",
        "        f'the maximum is {max_length}. You may be able to run it with the full '\n",
        "        f'Uni-Fold system depending on your resources (system memory, '\n",
        "        f'GPU memory).')\n",
        "  return clean_sequence\n",
        "\n",
        "\n",
        "def validate_input(\n",
        "    input_sequences: Sequence[str],\n",
        "    min_length: int,\n",
        "    max_length: int,\n",
        "    max_multimer_length: int) -> Tuple[Sequence[str], bool]:\n",
        "  \"\"\"Validates and cleans input sequences and determines which model to use.\"\"\"\n",
        "  sequences = []\n",
        "\n",
        "  for input_sequence in input_sequences:\n",
        "    if input_sequence.strip():\n",
        "      input_sequence = clean_and_validate_sequence(\n",
        "          input_sequence=input_sequence,\n",
        "          min_length=min_length,\n",
        "          max_length=max_length)\n",
        "      sequences.append(input_sequence)\n",
        "\n",
        "  if len(sequences) == 1:\n",
        "    print('Using the single-chain model.')\n",
        "    return sequences, False\n",
        "\n",
        "  elif len(sequences) > 1:\n",
        "    total_multimer_length = sum([len(seq) for seq in sequences])\n",
        "    if total_multimer_length > max_multimer_length:\n",
        "      raise ValueError(f'The total length of multimer sequences is too long: '\n",
        "                       f'{total_multimer_length}, while the maximum is '\n",
        "                       f'{max_multimer_length}. Please use the full AlphaFold '\n",
        "                       f'system for long multimers.')\n",
        "    print(f'Using the multimer model with {len(sequences)} sequences.')\n",
        "    return sequences, True\n",
        "\n",
        "  else:\n",
        "    raise ValueError('No input amino acid sequence provided, please provide at '\n",
        "                     'least one sequence.')\n",
        "\n",
        "def add_hash(x,y):\n",
        "    return x+\"_\"+hashlib.sha1(y.encode()).hexdigest()[:5]\n",
        "\n",
        "\n",
        "sequence_1 = 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVCTVNHRFYDPESKLWKSVCPHPGSGISFLKKYDYLLSEEGEKLQITEIKTFTTKQPVFIYHIQVENNHNFFANGVLAHAMQVSI'  #@param {type:\"string\"}\n",
        "sequence_2 = ''  #@param {type:\"string\"}\n",
        "sequence_3 = ''  #@param {type:\"string\"}\n",
        "sequence_4 = ''  #@param {type:\"string\"}\n",
        "\n",
        "use_templates = True #@param {type:\"boolean\"}\n",
        "msa_mode = \"MMseqs2\" #@param [\"MMseqs2\",\"single_sequence\"]\n",
        "\n",
        "input_sequences = [sequence_1, sequence_2, sequence_3, sequence_4]\n",
        "\n",
        "jobname = 'unifold_colab' #@param {type:\"string\"}\n",
        "\n",
        "basejobname = \"\".join(input_sequences)\n",
        "basejobname = re.sub(r'\\W+', '', basejobname)\n",
        "target_id = add_hash(jobname, basejobname)\n",
        "\n",
        "# Validate the input.\n",
        "sequences, is_multimer = validate_input(\n",
        "    input_sequences=input_sequences,\n",
        "    min_length=MIN_SINGLE_SEQUENCE_LENGTH,\n",
        "    max_length=MAX_SINGLE_SEQUENCE_LENGTH,\n",
        "    max_multimer_length=MAX_MULTIMER_LENGTH)\n",
        "\n",
        "descriptions = ['> '+target_id+' seq'+str(ii) for ii in range(len(sequences))]\n",
        "\n",
        "if is_multimer:\n",
        "    divide_multi_chains(target_id, output_dir_base, sequences, descriptions)\n",
        "    \n",
        "s = []\n",
        "for des, seq in zip(descriptions, sequences):\n",
        "    s += [des, seq]\n",
        "\n",
        "unique_sequences = []\n",
        "[unique_sequences.append(x) for x in sequences if x not in unique_sequences]\n",
        "\n",
        "if len(unique_sequences)==1:\n",
        "    homooligomers_num = len(sequences)\n",
        "else:\n",
        "    homooligomers_num = 1\n",
        "    \n",
        "with open(f\"{jobname}.fasta\", \"w\") as f:\n",
        "    f.write(\"\\n\".join(s))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "QThPtPvlbEN_"
      },
      "outputs": [],
      "source": [
        "#@title Generate homogeneous features via ColabFold-MMSeqs2 server\n",
        "#@markdown Acknowledge to [ColabFold](https://github.com/sokrypton/ColabFold.git)\n",
        "\n",
        "import tarfile\n",
        "import requests\n",
        "from tqdm import tqdm\n",
        "import time\n",
        "import logging\n",
        "\n",
        "from unifold.msa import templates, pipeline\n",
        "from unifold.msa.tools import hhsearch\n",
        "\n",
        "\n",
        "logger = logging.getLogger(__name__)\n",
        "\n",
        "TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'\n",
        "DEFAULT_API_SERVER = \"https://api.colabfold.com\"\n",
        "\n",
        "def run_mmseqs2(x, prefix, use_env=True, \n",
        "                use_templates=False, use_pairing=False,\n",
        "                host_url=\"https://api.colabfold.com\") -> Tuple[List[str], List[str]]:\n",
        "  submission_endpoint = \"ticket/pair\" if use_pairing else \"ticket/msa\"\n",
        "\n",
        "  def submit(seqs, mode, N=101):\n",
        "    n, query = N, \"\"\n",
        "    for seq in seqs:\n",
        "      query += f\">{n}\\n{seq}\\n\"\n",
        "      n += 1\n",
        "\n",
        "    res = requests.post(f'{host_url}/{submission_endpoint}', data={'q':query,'mode': mode})\n",
        "    try:\n",
        "      out = res.json()\n",
        "    except ValueError:\n",
        "      logger.error(f\"Server didn't reply with json: {res.text}\")\n",
        "      out = {\"status\":\"ERROR\"}\n",
        "    return out\n",
        "\n",
        "  def status(ID):\n",
        "    res = requests.get(f'{host_url}/ticket/{ID}')\n",
        "    try:\n",
        "      out = res.json()\n",
        "    except ValueError:\n",
        "      logger.error(f\"Server didn't reply with json: {res.text}\")\n",
        "      out = {\"status\":\"ERROR\"}\n",
        "    return out\n",
        "\n",
        "  def download(ID, path):\n",
        "    res = requests.get(f'{host_url}/result/download/{ID}')\n",
        "    with open(path,\"wb\") as out: out.write(res.content)\n",
        "\n",
        "  # process input x\n",
        "  seqs = [x] if isinstance(x, str) else x\n",
        "\n",
        "  mode = \"env\"\n",
        "  if use_pairing:\n",
        "    mode = \"\"\n",
        "    use_templates = False\n",
        "    use_env = False\n",
        "\n",
        "  # define path\n",
        "  path = f\"{prefix}\"\n",
        "  if not os.path.isdir(path): os.mkdir(path)\n",
        "\n",
        "  # call mmseqs2 api\n",
        "  tar_gz_file = f'{path}/out_{mode}.tar.gz'\n",
        "  N,REDO = 101,True\n",
        "\n",
        "  # deduplicate and keep track of order\n",
        "  seqs_unique = []\n",
        "  #TODO this might be slow for large sets\n",
        "  [seqs_unique.append(x) for x in seqs if x not in seqs_unique]\n",
        "  Ms = [N + seqs_unique.index(seq) for seq in seqs]\n",
        "  # lets do it!\n",
        "  if not os.path.isfile(tar_gz_file):\n",
        "    TIME_ESTIMATE = 150 * len(seqs_unique)\n",
        "    with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:\n",
        "      while REDO:\n",
        "        pbar.set_description(\"SUBMIT\")\n",
        "\n",
        "        # Resubmit job until it goes through\n",
        "        out = submit(seqs_unique, mode, N)\n",
        "        while out[\"status\"] in [\"UNKNOWN\", \"RATELIMIT\"]:\n",
        "          sleep_time = 5 + random.randint(0, 5)\n",
        "          logger.error(f\"Sleeping for {sleep_time}s. Reason: {out['status']}\")\n",
        "          # resubmit\n",
        "          time.sleep(sleep_time)\n",
        "          out = submit(seqs_unique, mode, N)\n",
        "\n",
        "        if out[\"status\"] == \"ERROR\":\n",
        "          raise Exception(f'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')\n",
        "\n",
        "        if out[\"status\"] == \"MAINTENANCE\":\n",
        "          raise Exception(f'MMseqs2 API is undergoing maintenance. Please try again in a few minutes.')\n",
        "\n",
        "        # wait for job to finish\n",
        "        ID,TIME = out[\"id\"],0\n",
        "        pbar.set_description(out[\"status\"])\n",
        "        while out[\"status\"] in [\"UNKNOWN\",\"RUNNING\",\"PENDING\"]:\n",
        "          t = 5 + random.randint(0,5)\n",
        "          logger.error(f\"Sleeping for {t}s. Reason: {out['status']}\")\n",
        "          time.sleep(t)\n",
        "          out = status(ID)\n",
        "          pbar.set_description(out[\"status\"])\n",
        "          if out[\"status\"] == \"RUNNING\":\n",
        "            TIME += t\n",
        "            pbar.update(n=t)\n",
        "\n",
        "        if out[\"status\"] == \"COMPLETE\":\n",
        "          if TIME < TIME_ESTIMATE:\n",
        "            pbar.update(n=(TIME_ESTIMATE-TIME))\n",
        "          REDO = False\n",
        "\n",
        "        if out[\"status\"] == \"ERROR\":\n",
        "          REDO = False\n",
        "          raise Exception(f'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')\n",
        "\n",
        "      # Download results\n",
        "      download(ID, tar_gz_file)\n",
        "\n",
        "  # prep list of a3m files\n",
        "  if use_pairing:\n",
        "    a3m_files = [f\"{path}/pair.a3m\"]\n",
        "  else:\n",
        "    a3m_files = [f\"{path}/uniref.a3m\"]\n",
        "    if use_env: a3m_files.append(f\"{path}/bfd.mgnify30.metaeuk30.smag30.a3m\")\n",
        "\n",
        "  # extract a3m files\n",
        "  if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files):\n",
        "    with tarfile.open(tar_gz_file) as tar_gz:\n",
        "      tar_gz.extractall(path)\n",
        "\n",
        "  # templates\n",
        "  if use_templates:\n",
        "    templates = {}\n",
        "\n",
        "    for line in open(f\"{path}/pdb70.m8\",\"r\"):\n",
        "      p = line.rstrip().split()\n",
        "      M,pdb,qid,e_value = p[0],p[1],p[2],p[10]\n",
        "      M = int(M)\n",
        "      if M not in templates: templates[M] = []\n",
        "      templates[M].append(pdb)\n",
        "\n",
        "    template_paths = {}\n",
        "    for k,TMPL in templates.items():\n",
        "      TMPL_PATH = f\"{prefix}/templates_{k}\"\n",
        "      if not os.path.isdir(TMPL_PATH):\n",
        "        os.mkdir(TMPL_PATH)\n",
        "        TMPL_LINE = \",\".join(TMPL[:20])\n",
        "        os.system(f\"curl -s -L {host_url}/template/{TMPL_LINE} | tar xzf - -C {TMPL_PATH}/\")\n",
        "        os.system(f\"cp {TMPL_PATH}/pdb70_a3m.ffindex {TMPL_PATH}/pdb70_cs219.ffindex\")\n",
        "        os.system(f\"touch {TMPL_PATH}/pdb70_cs219.ffdata\")\n",
        "      template_paths[k] = TMPL_PATH\n",
        "\n",
        "  # gather a3m lines\n",
        "  a3m_lines = {}\n",
        "  for a3m_file in a3m_files:\n",
        "    update_M,M = True,None\n",
        "    for line in open(a3m_file,\"r\"):\n",
        "      if len(line) > 0:\n",
        "        if \"\\x00\" in line:\n",
        "          line = line.replace(\"\\x00\",\"\")\n",
        "          update_M = True\n",
        "        if line.startswith(\">\") and update_M:\n",
        "          M = int(line[1:].rstrip())\n",
        "          update_M = False\n",
        "          if M not in a3m_lines: a3m_lines[M] = []\n",
        "        a3m_lines[M].append(line)\n",
        "\n",
        "  # return results\n",
        "\n",
        "  a3m_lines = [\"\".join(a3m_lines[n]) for n in Ms]\n",
        "\n",
        "  if use_templates:\n",
        "    template_paths_ = []\n",
        "    for n in Ms:\n",
        "      if n not in template_paths:\n",
        "        template_paths_.append(None)\n",
        "        #print(f\"{n-N}\\tno_templates_found\")\n",
        "      else:\n",
        "        template_paths_.append(template_paths[n])\n",
        "    template_paths = template_paths_\n",
        "\n",
        "\n",
        "  return (a3m_lines, template_paths) if use_templates else a3m_lines\n",
        "\n",
        "def get_null_template(\n",
        "    query_sequence: Union[List[str], str], num_temp: int = 1\n",
        ") -> Dict[str, Any]:\n",
        "    ln = (\n",
        "        len(query_sequence)\n",
        "        if isinstance(query_sequence, str)\n",
        "        else sum(len(s) for s in query_sequence)\n",
        "    )\n",
        "    output_templates_sequence = \"A\" * ln\n",
        "    output_confidence_scores = np.full(ln, 1.0)\n",
        "\n",
        "    templates_all_atom_positions = np.zeros(\n",
        "        (ln, templates.residue_constants.atom_type_num, 3)\n",
        "    )\n",
        "    templates_all_atom_masks = np.zeros((ln, templates.residue_constants.atom_type_num))\n",
        "    templates_aatype = templates.residue_constants.sequence_to_onehot(\n",
        "        output_templates_sequence, templates.residue_constants.HHBLITS_AA_TO_ID\n",
        "    )\n",
        "    template_features = {\n",
        "        \"template_all_atom_positions\": np.tile(\n",
        "            templates_all_atom_positions[None], [num_temp, 1, 1, 1]\n",
        "        ),\n",
        "        \"template_all_atom_masks\": np.tile(\n",
        "            templates_all_atom_masks[None], [num_temp, 1, 1]\n",
        "        ),\n",
        "        \"template_sequence\": [f\"none\".encode()] * num_temp,\n",
        "        \"template_aatype\": np.tile(np.array(templates_aatype)[None], [num_temp, 1, 1]),\n",
        "        \"template_domain_names\": [f\"none\".encode()] * num_temp,\n",
        "        \"template_sum_probs\": np.zeros([num_temp], dtype=np.float32),\n",
        "    }\n",
        "    return template_features\n",
        "\n",
        "\n",
        "def get_template(\n",
        "    a3m_lines: str, template_path: str, query_sequence: str\n",
        ") -> Dict[str, Any]:\n",
        "    template_featurizer = templates.HhsearchHitFeaturizer(\n",
        "        mmcif_dir=template_path,\n",
        "        max_template_date=\"2100-01-01\",\n",
        "        max_hits=20,\n",
        "        kalign_binary_path=\"kalign\",\n",
        "        release_dates_path=None,\n",
        "        obsolete_pdbs_path=None,\n",
        "    )\n",
        "\n",
        "    hhsearch_pdb70_runner = hhsearch.HHSearch(\n",
        "        binary_path=\"hhsearch\", databases=[f\"{template_path}/pdb70\"]\n",
        "    )\n",
        "\n",
        "    hhsearch_result = hhsearch_pdb70_runner.query(a3m_lines)\n",
        "    hhsearch_hits = pipeline.parsers.parse_hhr(hhsearch_result)\n",
        "    templates_result = template_featurizer.get_templates(\n",
        "        query_sequence=query_sequence, hits=hhsearch_hits\n",
        "    )\n",
        "    return dict(templates_result.features)\n",
        "  \n",
        "def get_msa_and_templates(\n",
        "    jobname: str,\n",
        "    query_seqs_unique: Union[str, List[str]],\n",
        "    result_dir: Path,\n",
        "    msa_mode: str,\n",
        "    use_templates: bool,\n",
        "    homooligomers_num: int = 1,\n",
        "    host_url: str = DEFAULT_API_SERVER,\n",
        ") -> Tuple[\n",
        "    Optional[List[str]], Optional[List[str]], List[str], List[int], List[Dict[str, Any]]\n",
        "]:\n",
        "    \n",
        "    use_env = msa_mode == \"MMseqs2\"\n",
        "\n",
        "    template_features = []\n",
        "    if use_templates:\n",
        "        a3m_lines_mmseqs2, template_paths = run_mmseqs2(\n",
        "            query_seqs_unique,\n",
        "            str(result_dir.joinpath(jobname)),\n",
        "            use_env,\n",
        "            use_templates=True,\n",
        "            host_url=host_url,\n",
        "        )\n",
        "        if template_paths is None:\n",
        "            logger.info(\"No template detected\")\n",
        "            for index in range(0, len(query_seqs_unique)):\n",
        "                template_feature = get_null_template(query_seqs_unique[index])\n",
        "                template_features.append(template_feature)\n",
        "        else:\n",
        "            for index in range(0, len(query_seqs_unique)):\n",
        "                if template_paths[index] is not None:\n",
        "                    template_feature = get_template(\n",
        "                        a3m_lines_mmseqs2[index],\n",
        "                        template_paths[index],\n",
        "                        query_seqs_unique[index],\n",
        "                    )\n",
        "                    if len(template_feature[\"template_domain_names\"]) == 0:\n",
        "                        template_feature = get_null_template(query_seqs_unique[index])\n",
        "                        logger.info(f\"Sequence {index} found no templates\")\n",
        "                    else:\n",
        "                        logger.info(\n",
        "                            f\"Sequence {index} found templates: {template_feature['template_domain_names'].astype(str).tolist()}\"\n",
        "                        )\n",
        "                else:\n",
        "                    template_feature = get_null_template(query_seqs_unique[index])\n",
        "                    logger.info(f\"Sequence {index} found no templates\")\n",
        "\n",
        "                template_features.append(template_feature)\n",
        "    else:\n",
        "        for index in range(0, len(query_seqs_unique)):\n",
        "            template_feature = get_null_template(query_seqs_unique[index])\n",
        "            template_features.append(template_feature)\n",
        "\n",
        "\n",
        "    if msa_mode == \"single_sequence\":\n",
        "        a3m_lines = []\n",
        "        num = 101\n",
        "        for i, seq in enumerate(query_seqs_unique):\n",
        "            a3m_lines.append(\">\" + str(num + i) + \"\\n\" + seq)\n",
        "    else:\n",
        "        # find normal a3ms\n",
        "        a3m_lines = run_mmseqs2(\n",
        "            query_seqs_unique,\n",
        "            str(result_dir.joinpath(jobname)),\n",
        "            use_env,\n",
        "            use_pairing=False,\n",
        "            host_url=host_url,\n",
        "        )\n",
        "    if len(query_seqs_unique)>1:\n",
        "        # find paired a3m if not a homooligomers\n",
        "        paired_a3m_lines = run_mmseqs2(\n",
        "            query_seqs_unique,\n",
        "            str(result_dir.joinpath(jobname)),\n",
        "            use_env,\n",
        "            use_pairing=True,\n",
        "            host_url=host_url,\n",
        "        )\n",
        "    else:\n",
        "        num = 101\n",
        "        paired_a3m_lines = []\n",
        "        for i in range(0, homooligomers_num):\n",
        "            paired_a3m_lines.append(\n",
        "                \">\" + str(num + i) + \"\\n\" + query_seqs_unique[0] + \"\\n\"\n",
        "            )\n",
        "\n",
        "    return (\n",
        "        a3m_lines,\n",
        "        paired_a3m_lines,\n",
        "        template_features,\n",
        "    )\n",
        "\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "RWwTgjo4bEOB"
      },
      "outputs": [],
      "source": [
        "#@title Process features for Uni-Fold prediction\n",
        "import pickle\n",
        "import gzip\n",
        "from unifold.msa import parsers\n",
        "from unifold.msa import pipeline\n",
        "from unifold.data.utils import compress_features\n",
        "from unifold.data.protein import PDB_CHAIN_IDS\n",
        "\n",
        "result_dir = Path(output_dir_base)\n",
        "output_dir = os.path.join(output_dir_base, target_id)\n",
        "\n",
        "(\n",
        "  unpaired_msa,\n",
        "  paired_msa,\n",
        "  template_results,\n",
        ") = get_msa_and_templates(\n",
        "  target_id,\n",
        "  unique_sequences,\n",
        "  result_dir=result_dir,\n",
        "  msa_mode=msa_mode,\n",
        "  use_templates=use_templates,\n",
        "  homooligomers_num = homooligomers_num\n",
        ")\n",
        "\n",
        "\n",
        "for idx, seq in enumerate(unique_sequences):\n",
        "    chain_id = PDB_CHAIN_IDS[idx]\n",
        "    sequence_features = pipeline.make_sequence_features(\n",
        "              sequence=seq, description=f'> {jobname} seq {chain_id}', num_res=len(seq)\n",
        "          )\n",
        "    monomer_msa = parsers.parse_a3m(unpaired_msa[idx])\n",
        "    msa_features = pipeline.make_msa_features([monomer_msa])\n",
        "    template_features = template_results[idx]\n",
        "    feature_dict = {**sequence_features, **msa_features, **template_features}\n",
        "    feature_dict = compress_features(feature_dict)\n",
        "    features_output_path = os.path.join(\n",
        "            output_dir, \"{}.feature.pkl.gz\".format(chain_id)\n",
        "        )\n",
        "    pickle.dump(\n",
        "        feature_dict, \n",
        "        gzip.GzipFile(features_output_path, \"wb\"), \n",
        "        protocol=4\n",
        "        )\n",
        "    if is_multimer:\n",
        "        multimer_msa = parsers.parse_a3m(paired_msa[idx])\n",
        "        pair_features = pipeline.make_msa_features([multimer_msa])\n",
        "        pair_feature_dict = compress_features(pair_features)\n",
        "        uniprot_output_path = os.path.join(\n",
        "            output_dir, \"{}.uniprot.pkl.gz\".format(chain_id)\n",
        "        )\n",
        "        pickle.dump(\n",
        "            pair_feature_dict,\n",
        "            gzip.GzipFile(uniprot_output_path, \"wb\"),\n",
        "            protocol=4,\n",
        "        )\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "RJUxaO7Ofw1L"
      },
      "outputs": [],
      "source": [
        "#@title Uni-Fold prediction\n",
        "\n",
        "import torch\n",
        "import json\n",
        "from unifold.config import model_config\n",
        "from unifold.modules.alphafold import AlphaFold\n",
        "from unifold.dataset import load_and_process, UnifoldDataset\n",
        "from unicore.utils import (\n",
        "    tensor_tree_map,\n",
        ")\n",
        "\n",
        "def automatic_chunk_size(seq_len):\n",
        "    if seq_len < 512:\n",
        "        chunk_size = 256\n",
        "    elif seq_len < 1024:\n",
        "        chunk_size = 128\n",
        "    elif seq_len < 2048:\n",
        "        chunk_size = 32\n",
        "    elif seq_len < 3072:\n",
        "        chunk_size = 16\n",
        "    else:\n",
        "        chunk_size = 1\n",
        "    return chunk_size\n",
        "\n",
        "\n",
        "def load_feature_for_one_target(\n",
        "    config, data_folder, seed=0, is_multimer=False, use_uniprot=False\n",
        "):\n",
        "    if not is_multimer:\n",
        "        uniprot_msa_dir = None\n",
        "        sequence_ids = [\"A\"]\n",
        "        if use_uniprot:\n",
        "            uniprot_msa_dir = data_folder\n",
        "\n",
        "    else:\n",
        "        uniprot_msa_dir = data_folder\n",
        "        sequence_ids = open(os.path.join(data_folder, \"chains.txt\")).readline().split()\n",
        "    batch, _ = load_and_process(\n",
        "        config=config.data,\n",
        "        mode=\"predict\",\n",
        "        seed=seed,\n",
        "        batch_idx=None,\n",
        "        data_idx=0,\n",
        "        is_distillation=False,\n",
        "        sequence_ids=sequence_ids,\n",
        "        monomer_feature_dir=data_folder,\n",
        "        uniprot_msa_dir=uniprot_msa_dir,\n",
        "    )\n",
        "    batch = UnifoldDataset.collater([batch])\n",
        "    return batch\n",
        "\n",
        "if is_multimer:\n",
        "    model_name = \"multimer_ft\"\n",
        "    param_path = \"./multimer.unifold.pt\"\n",
        "else:\n",
        "    model_name = \"model_2_ft\"\n",
        "    param_path = \"./monomer.unifold.pt\"\n",
        "\n",
        "max_recycling_iters = 3 #@param {type:\"integer\"}\n",
        "num_ensembles = 2 #@param {type:\"integer\"}\n",
        "manual_seed = 42 #@param {type:\"integer\"}\n",
        "times = 3 #@param {type:\"integer\"}\n",
        "\n",
        "config = model_config(model_name)\n",
        "config.data.common.max_recycling_iters = max_recycling_iters\n",
        "config.globals.max_recycling_iters = max_recycling_iters\n",
        "config.data.predict.num_ensembles = num_ensembles\n",
        "\n",
        "# faster prediction with large chunk\n",
        "config.globals.chunk_size = 128\n",
        "model = AlphaFold(config)\n",
        "print(\"start to load params {}\".format(param_path))\n",
        "state_dict = torch.load(param_path)[\"ema\"][\"params\"]\n",
        "state_dict = {\".\".join(k.split(\".\")[1:]): v for k, v in state_dict.items()}\n",
        "model.load_state_dict(state_dict)\n",
        "model = model.to(\"cuda:0\")\n",
        "model.eval()\n",
        "\n",
        "# data path is based on target_name\n",
        "cur_param_path_postfix = os.path.split(param_path)[-1]\n",
        "\n",
        "print(\"start to predict {}\".format(target_id))\n",
        "plddts = {}\n",
        "ptms = {}\n",
        "best_protein = None\n",
        "best_score = 0\n",
        "best_plddt = None\n",
        "best_pae = None\n",
        "\n",
        "for seed in range(times):\n",
        "    cur_seed = hash((manual_seed, seed)) % 100000\n",
        "    batch = load_feature_for_one_target(\n",
        "        config,\n",
        "        output_dir,\n",
        "        cur_seed,\n",
        "        is_multimer=is_multimer,\n",
        "        use_uniprot=is_multimer,\n",
        "    )\n",
        "    seq_len = batch[\"aatype\"].shape[-1]\n",
        "    model.globals.chunk_size = automatic_chunk_size(seq_len)\n",
        "\n",
        "    with torch.no_grad():\n",
        "        batch = {\n",
        "            k: torch.as_tensor(v, device=\"cuda:0\")\n",
        "            for k, v in batch.items()\n",
        "        }\n",
        "        shapes = {k: v.shape for k, v in batch.items()}\n",
        "        print(shapes)\n",
        "        t = time.perf_counter()\n",
        "        out = model(batch)\n",
        "        print(f\"Inference time: {time.perf_counter() - t}\")\n",
        "\n",
        "    def to_float(x):\n",
        "        if x.dtype == torch.bfloat16 or x.dtype == torch.half:\n",
        "            return x.float()\n",
        "        else:\n",
        "            return x\n",
        "\n",
        "    # Toss out the recycling dimensions --- we don't need them anymore\n",
        "    batch = tensor_tree_map(lambda t: t[-1, 0, ...], batch)\n",
        "    batch = tensor_tree_map(to_float, batch)\n",
        "    out = tensor_tree_map(lambda t: t[0, ...], out)\n",
        "    out = tensor_tree_map(to_float, out)\n",
        "    batch = tensor_tree_map(lambda x: np.array(x.cpu()), batch)\n",
        "    out = tensor_tree_map(lambda x: np.array(x.cpu()), out)\n",
        "\n",
        "    plddt = out[\"plddt\"]\n",
        "    mean_plddt = np.mean(plddt)\n",
        "    plddt_b_factors = np.repeat(\n",
        "        plddt[..., None], residue_constants.atom_type_num, axis=-1\n",
        "    )\n",
        "    # TODO: , may need to reorder chains, based on entity_ids\n",
        "    cur_protein = protein.from_prediction(\n",
        "        features=batch, result=out, b_factors=plddt_b_factors\n",
        "    )\n",
        "    cur_save_name = (\n",
        "        f\"{cur_param_path_postfix}_{cur_seed}\"\n",
        "    )\n",
        "    plddts[cur_save_name] = str(mean_plddt)\n",
        "    if is_multimer:\n",
        "        ptms[cur_save_name] = str(np.mean(out[\"iptm+ptm\"]))\n",
        "    with open(os.path.join(output_dir, cur_save_name + '.pdb'), \"w\") as f:\n",
        "        f.write(protein.to_pdb(cur_protein))\n",
        "\n",
        "    if is_multimer:\n",
        "        mean_ptm = np.mean(out[\"iptm+ptm\"])\n",
        "        if mean_ptm>best_score:\n",
        "            best_protein = cur_protein\n",
        "            best_pae = out[\"predicted_aligned_error\"]\n",
        "            best_plddt = out[\"plddt\"]\n",
        "            best_score = mean_ptm\n",
        "    else:\n",
        "        if mean_plddt>best_score:\n",
        "            best_protein = cur_protein\n",
        "            best_plddt = out[\"plddt\"]\n",
        "            best_score = mean_plddt\n",
        "\n",
        "print(\"plddts\", plddts)\n",
        "score_name = f\"{model_name}_{cur_param_path_postfix}\"\n",
        "plddt_fname = score_name + \"_plddt.json\"\n",
        "json.dump(plddts, open(os.path.join(output_dir, plddt_fname), \"w\"), indent=4)\n",
        "if ptms:\n",
        "    print(\"ptms\", ptms)\n",
        "    ptm_fname = score_name + \"_ptm.json\"\n",
        "    json.dump(ptms, open(os.path.join(output_dir, ptm_fname), \"w\"), indent=4)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "kryWdmg0jZwT"
      },
      "outputs": [],
      "source": [
        "#@title Show the protein structure\n",
        "\n",
        "# Construct multiclass b-factors to indicate confidence bands\n",
        "# 0=very low, 1=low, 2=confident, 3=very high\n",
        "# Color bands for visualizing plddt\n",
        "import py3Dmol\n",
        "import matplotlib.pyplot as plt\n",
        "from matplotlib.colors import LinearSegmentedColormap\n",
        "from IPython import display\n",
        "from ipywidgets import GridspecLayout\n",
        "from ipywidgets import Output\n",
        "\n",
        "\n",
        "show_sidechains = False #@param {type:\"boolean\"}\n",
        "dpi = 100 #@param {type:\"integer\"}\n",
        "\n",
        "to_visualize_pdb = protein.to_pdb(best_protein)\n",
        "\n",
        "PLDDT_BANDS = [(0., 0.50, '#FF7D45'),\n",
        "               (0.50, 0.70, '#FFDB13'),\n",
        "               (0.70, 0.90, '#65CBF3'),\n",
        "               (0.90, 1.00, '#0053D6')]\n",
        "\n",
        "\n",
        "# --- Visualise the prediction & confidence ---\n",
        "def plot_plddt_legend():\n",
        "    \"\"\"Plots the legend for pLDDT.\"\"\"\n",
        "    thresh = ['Very low (pLDDT < 50)',\n",
        "              'Low (70 > pLDDT > 50)',\n",
        "              'Confident (90 > pLDDT > 70)',\n",
        "              'Very high (pLDDT > 90)']\n",
        "\n",
        "    colors = [x[2] for x in PLDDT_BANDS]\n",
        "\n",
        "    plt.figure(figsize=(2, 2))\n",
        "    for c in colors:\n",
        "        plt.bar(0, 0, color=c)\n",
        "    plt.legend(thresh, frameon=False, loc='center', fontsize=20)\n",
        "    plt.xticks([])\n",
        "    plt.yticks([])\n",
        "    ax = plt.gca()\n",
        "    ax.spines['right'].set_visible(False)\n",
        "    ax.spines['top'].set_visible(False)\n",
        "    ax.spines['left'].set_visible(False)\n",
        "    ax.spines['bottom'].set_visible(False)\n",
        "    plt.title('Model Confidence', fontsize=20, pad=20)\n",
        "    return plt\n",
        "\n",
        "\n",
        "if is_multimer:\n",
        "    multichain_view = py3Dmol.view(width=800, height=600)\n",
        "    multichain_view.addModelsAsFrames(to_visualize_pdb)\n",
        "    multichain_style = {'cartoon': {'colorscheme': 'chain'}}\n",
        "    multichain_view.setStyle({'model': -1}, multichain_style)\n",
        "    multichain_view.zoomTo()\n",
        "    multichain_view.show()\n",
        "\n",
        "# Color the structure by per-residue pLDDT\n",
        "view = py3Dmol.view(width=800, height=600)\n",
        "view.addModelsAsFrames(to_visualize_pdb)\n",
        "style = {'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':0.5,'max':0.9}}}\n",
        "if show_sidechains:\n",
        "    style['stick'] = {}\n",
        "view.setStyle({'model':-1}, style)\n",
        "view.zoomTo()\n",
        "\n",
        "grid = GridspecLayout(1, 2)\n",
        "out = Output()\n",
        "with out:\n",
        "    view.show()\n",
        "grid[0, 0] = out\n",
        "\n",
        "out = Output()\n",
        "with out:\n",
        "    plot_plddt_legend().show()\n",
        "grid[0, 1] = out\n",
        "\n",
        "display.display(grid)\n",
        "\n",
        "# Display pLDDT and predicted aligned error (if output by the model).\n",
        "if is_multimer:\n",
        "  num_plots = 2\n",
        "else:\n",
        "  num_plots = 1\n",
        "\n",
        "plt.figure(figsize=[8 * num_plots , 6])\n",
        "plt.subplot(1, num_plots, 1)\n",
        "plt.plot(plddt*100)\n",
        "plt.title('Predicted LDDT')\n",
        "plt.xlabel('Residue')\n",
        "plt.ylabel('pLDDT')\n",
        "plt.grid()\n",
        "plddt_svg_path = os.path.join(output_dir, 'plddt.svg')\n",
        "plt.savefig(plddt_svg_path, dpi=dpi, bbox_inches='tight')\n",
        "\n",
        "\n",
        "if num_plots == 2:\n",
        "    plt.subplot(1, 2, 2)\n",
        "    max_pae = np.max(best_pae)\n",
        "    colors = ['#0F006F','#245AE6','#55CCFF','#FFFFFF']\n",
        "\n",
        "    cmap = LinearSegmentedColormap.from_list('mymap', colors)\n",
        "    im = plt.imshow(best_pae, vmin=0., vmax=max_pae, cmap=cmap)\n",
        "    plt.colorbar(im, fraction=0.046, pad=0.04)\n",
        "\n",
        "    # Display lines at chain boundaries.\n",
        "    total_num_res = best_protein.residue_index.shape[-1]\n",
        "    chain_ids = best_protein.chain_index\n",
        "    for chain_boundary in np.nonzero(chain_ids[:-1] - chain_ids[1:]):\n",
        "        if chain_boundary.size:\n",
        "            plt.plot([0, total_num_res], [chain_boundary, chain_boundary], color='red')\n",
        "            plt.plot([chain_boundary, chain_boundary], [0, total_num_res], color='red')\n",
        "\n",
        "    plt.title('Predicted Aligned Error')\n",
        "    plt.xlabel('Scored residue')\n",
        "    plt.ylabel('Aligned residue')\n",
        "    pae_svg_path = os.path.join(output_dir, 'pae.svg')\n",
        "    plt.savefig(pae_svg_path, dpi=dpi, bbox_inches='tight')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form"
      },
      "outputs": [],
      "source": [
        "#@title Download the prediction\n",
        "#@markdown **The content of zip file**:\n",
        "#@markdown 1. PDB formatted structures\n",
        "#@markdown 2. Json file of the model quality (pLDDT and pTM for multimer)\n",
        "#@markdown 2. Plots of the model quality (pLDDT and PAE for multimer)\n",
        "\n",
        "from google.colab import files\n",
        "\n",
        "\n",
        "plddt_file = os.path.join(output_dir, plddt_fname)\n",
        "\n",
        "pdb_files = [os.path.join(output_dir, pdb_name + '.pdb') for pdb_name in plddts]\n",
        "file_lists = pdb_files + [\n",
        "    plddt_file, plddt_svg_path\n",
        "]\n",
        "if is_multimer:\n",
        "  ptm_file = os.path.join(output_dir, ptm_fname)\n",
        "  file_lists.append(ptm_file)\n",
        "  file_lists.append(pae_svg_path)\n",
        "\n",
        "!zip -q {target_id}.zip {\" \".join(file_lists)}\n",
        "files.download(f'{target_id}.zip')"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "unifold.ipynb",
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3.8",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
